#!/bin/bash

# Launch fast_inference.py for one model on one dataset, with in-context
# example selection controlled by the options below.
#
# Usage:
#   <script> <model_name> <peft_path> [data] [decoding] [options...]
#
# Positional arguments:
#   model_name  forwarded to fast_inference.py as --model_path
#   peft_path   used only to build the output directory name
#               (NOTE(review): it is not forwarded to fast_inference.py —
#               confirm that is intended)
#   data        dataset name, also used as --prompt_template_style
#               (default: gsm8k)
#   decoding    decoding preset expanded into CLI flags by
#               decoding_args_helper.py (default: greedy)
#   data/decoding may be omitted even when options follow; an argument
#   starting with "--" is always treated as an option.
#
# Options (each takes a value):
#   --subset_size (96)  --k (4)  --exp_num (5)  --method (knn)
#   --dp_choice ("" = not passed)  --emb (all-roberta-large-v1)
#   --metric (cosine_similarity)
#   --apply_chat_template [true|false]  (bare flag means true; default false)
#
# Results go to
#   results/test/<peft_path>-<data>-<method>-<metric>-<exp_num>-<subset_size>[-<decoding>]/<emb>
# where the -<decoding> suffix is added only for non-greedy decoding.

set -euo pipefail

usage() {
    printf 'Usage: %s <model_name> <peft_path> [data] [decoding] [--subset_size SUBSET_SIZE] [--k K] [--exp_num EXP_NUM] [--method METHOD] [--dp_choice DP_CHOICE] [--emb EMB] [--metric METRIC] [--apply_chat_template [true|false]]\n' "$0" >&2
    exit 1
}

die() {
    printf '%s\n' "$*" >&2
    exit 1
}

# Abort unless the option in $1 is followed by a value in $2.
# Without this check, `shift 2` on a trailing option fails without shifting
# and the parse loop never terminates.
require_value() {
    (( $# >= 2 )) || die "Missing value for parameter: $1"
}

(( $# >= 2 )) || usage

MDL=$1
MDLB=$2
shift 2

# Optional positionals. They are consumed only when present and not an option,
# so 2- and 3-argument invocations work (a fixed `shift 4` fails on those).
data=gsm8k
decoding=greedy
if (( $# > 0 )) && [[ $1 != --* ]]; then
    data=${1:-gsm8k}
    shift
    if (( $# > 0 )) && [[ $1 != --* ]]; then
        decoding=${1:-greedy}
        shift
    fi
fi

subset_size=96
k=4
exp_num=5
method="knn"
dp_choice=""
emb="all-roberta-large-v1"
metric="cosine_similarity"
apply_chat_template=false

while (( $# > 0 )); do
    case "$1" in
        --subset_size) require_value "$@"; subset_size=$2; shift 2 ;;
        --k)           require_value "$@"; k=$2;           shift 2 ;;
        --exp_num)     require_value "$@"; exp_num=$2;     shift 2 ;;
        --method)      require_value "$@"; method=$2;      shift 2 ;;
        --dp_choice)   require_value "$@"; dp_choice=$2;   shift 2 ;;
        --emb)         require_value "$@"; emb=$2;         shift 2 ;;
        --metric)      require_value "$@"; metric=$2;      shift 2 ;;
        --apply_chat_template)
            # Accept an explicit true/false value, or treat a bare flag as true.
            if (( $# >= 2 )) && [[ $2 != --* ]]; then
                apply_chat_template=$2
                shift 2
            else
                apply_chat_template=true
                shift
            fi
            [[ $apply_chat_template == true || $apply_chat_template == false ]] \
                || die "--apply_chat_template expects true or false, got: $apply_chat_template"
            ;;
        *)
            die "Unknown parameter: $1"
            ;;
    esac
done

echo "Model: $MDL"
echo "Peft Path: $MDLB"
echo "Data: $data"
echo "Decoding: $decoding"
echo "Subset Size: $subset_size"
echo "K: $k"
echo "Exp Num: $exp_num"
echo "Method: $method"
echo "DP Choice: $dp_choice"
echo "Embedding: $emb"
echo "Metric: $metric"
echo "Apply Chat Template: $apply_chat_template"

MDLBT=${MDLB}-${data}-${method}-${metric}-${exp_num}-${subset_size}
if [[ $decoding != greedy ]]; then
    MDLBT=${MDLBT}-${decoding}
fi

output=results/test/$MDLBT/$emb

# The helper's first output line is split on whitespace into separate CLI flags
# (presumably it prints all flags on one line — confirm against the helper).
helper_out=$(python decoding_args_helper.py "$decoding") \
    || die "decoding_args_helper.py failed for decoding: $decoding"
read -r -a decoding_flags <<< "$helper_out"
# ${arr[*]-} / ${arr[@]+...} guard against the empty-array "unbound variable"
# error under `set -u` on bash < 4.4.
echo "Decoding Flags: ${decoding_flags[*]-}"

# Build the argv once; optional flags are appended in the original order.
cmd=(
    python -u fast_inference.py
    --model_path "$MDL"
    --dataset "$data"
    --prompt_template_style "$data"
    --output "$output"
    --subset_size "$subset_size"
    --k "$k"
    --exp_num "$exp_num"
    --method "$method"
)
if [[ -n $dp_choice ]]; then
    cmd+=(--dp_choice "$dp_choice")
fi
cmd+=(--emb "$emb" --metric "$metric")
cmd+=(${decoding_flags[@]+"${decoding_flags[@]}"} --freq 16)
if [[ $apply_chat_template == true ]]; then
    cmd+=(--apply_chat_template)
fi

"${cmd[@]}"
